#####
##### Purpose: Replicating the Alternative Specifications Section 
#####          of the Online Appendix 
#####

###
### Setting up the Space: Packages and Functions
###

library(tidyverse)
library(data.table)
library(dtplyr)
library(tidytext)
library(cowplot)
library(ggbeeswarm)

# Cross-validated, out-of-sample party prediction from document text.
#
# Builds a quanteda document-feature matrix (dfm) from `dat`, optionally
# appends date/volume covariates, then on each of `n_folds` independent
# random train/test splits fits a ranger probability forest and scores the
# held-out documents. Folds run in parallel through doMC.
#
# Arguments
#   dat              tibble, one row per document. Besides the columns named
#                    below it must have a `caucus` column coded "R"/"D",
#                    which is used to score the predictions.
#   outvar           name of the outcome column. Its 1/TRUE level is assumed
#                    to mean Republican (as with "is_republican").
#   text_col         name of the column holding the document text.
#   doc_col          name of the column holding a unique document id; it is
#                    copied into `dat$index`.
#   to_compound      character vector of phrases to join into single tokens
#                    (added to a short built-in list).
#   stem             stem tokens? (default TRUE)
#   ngrams           n-gram sizes for quanteda::tokens_ngrams(); NULL
#                    (default) skips n-gramming. The final model uses 1:3.
#   date_var         name of a date column to append as a feature, or NULL
#                    (default). The final model uses it.
#   volume_var       name of a volume column to append as a feature, or NULL
#                    (default; not used in the final model).
#   weightvar        name of a column whose log(value + 1) scales each
#                    document's term counts, or NULL (default; not used in
#                    the final model).
#   store_importance ranger `importance` mode (default "none").
#   term_min         minimum share of documents a term must appear in to be
#                    kept (default 0.01 = 1%). Terms appearing in more than
#                    half of the documents are always dropped.
#   trainshare       share of documents in each training set (default 0.7).
#   n_folds          number of random train/test splits (default 15).
#   seed             seed for the train/test splits and the forests.
#
# Returns a list with
#   fit_results  data.frame: accuracy, ROC AUC and PR AUC for each fold,
#   preds        held-out rows of `dat` with predictions, stacked over folds,
#   importance   list of per-fold variable importances, or NULL.
text_to_pred_multifolds <- function(dat, outvar,
                                    text_col, doc_col,
                                    to_compound,
                                    stem = TRUE,
                                    ngrams = NULL,
                                    date_var = NULL,
                                    volume_var = NULL,
                                    weightvar = NULL,
                                    store_importance = "none",
                                    term_min = .01,
                                    trainshare = .7,
                                    n_folds = 15,
                                    seed = 11111) {
  # library() rather than require(): a missing package should stop here,
  # not surface later as an obscure "could not find function" error.
  library(tidyverse)
  library(quanteda)
  library(foreach)
  library(doMC)

  # detectCores() may return NA, and `cores - 1` is 0 on a single-core
  # machine; always register at least one worker.
  n_cores <- max(1L, parallel::detectCores() - 1L, na.rm = TRUE)
  registerDoMC(cores = n_cores)

  set.seed(seed)

  # define index
  dat$index <- dat %>% pull(doc_col)

  message("Making corpus")
  corp <- quanteda::corpus(dat, text_field = text_col, docid_field = "index")

  message("Setting compounds...")

  compounds <- quanteda::phrase(
    c("white house", "fox news", "united states",
      "united nations",
      "social media", "climate change", "right wing",
      to_compound)
  )

  message("Tokenizing...")

  # tokenize/preprocess/remove extraneous twitter stuff
  # NOTE(review): `remove_twitter` is deprecated in newer quanteda releases
  # (it is then ignored with a warning); kept for the original environment.
  tok <- quanteda::tokens(corp,
                          remove_punct = TRUE,
                          remove_twitter = TRUE,
                          remove_numbers = TRUE,
                          remove_url = TRUE) %>%
    quanteda::tokens_remove("\\p{Z}", valuetype = "regex") %>%
    quanteda::tokens_remove("via", valuetype = "regex") %>%

    # remove stop words
    quanteda::tokens_remove(quanteda::stopwords(source = "snowball")) %>%

    # concatenate common bigrams
    quanteda::tokens_compound(pattern = compounds)

  if (isTRUE(stem)) {
    message("Stemming...")
    tok <- tok %>% quanteda::tokens_wordstem()
  }
  if (!is.null(ngrams)) {
    message(paste0("ngramming to n = ", ngrams, "..."))
    tok <- tok %>% quanteda::tokens_ngrams(ngrams)
  }

  # make full document frequency matrix
  dfm.full <- quanteda::dfm(tok, verbose = TRUE, tolower = TRUE)

  # discard very frequent/infrequent terms
  message("Trimming dfm...")

  dfm.lim <- quanteda::dfm_trim(
    dfm.full,
    min_docfreq = term_min,
    max_docfreq = .5,
    docfreq_type = "prop")

  message(paste0("Retained ", ncol(dfm.lim), " tokens for analysis..."))

  if (!is.null(weightvar)) {
    message("Weighting dfm...")
    # Look every document's weight up once instead of scanning `dat` per row.
    doc_weights <- as.numeric(
      dat[[weightvar]][match(rownames(dfm.lim), as.character(dat$index))]
    )
    for (r in seq_len(nrow(dfm.lim))) {
      dfm.lim[r, ] <- dfm.lim[r, ] * log(doc_weights[r] + 1)
    }
  }

  message("Prepping IV and DV...")

  # put back in matrix
  X <- as.matrix(dfm.lim)

  # Find documents with no retained terms *before* appending covariates;
  # otherwise any non-zero date/volume value would keep empty documents in.
  empty_docs <- which(rowSums(X) == 0)

  if (!is.null(date_var)) {
    # make date variable and append to text features
    dates <- dat[match(rownames(X), dat$index), date_var]
    X <- cbind(X, dates)
  }

  if (!is.null(volume_var)) {
    # make volume variable and append to text features
    volume <- dat[match(rownames(X), dat$index), volume_var]
    X <- cbind(X, volume)
  }

  # make outcome vector, aligned with the rows of X
  Y <- dat[[outvar]][match(rownames(X), dat$index)]

  # discard empty docs
  if (length(empty_docs) > 0) {
    Y <- Y[-empty_docs]
    X <- X[-empty_docs, , drop = FALSE]
  }

  # Draw every train/test split here, in the parent process. doMC runs the
  # folds in forked workers via parallel::mclapply, which reseeds each
  # worker by default, so sampling inside %dopar% would ignore `seed`.
  n_docs <- length(Y)
  n_train <- floor(n_docs * trainshare)
  train_sets <- lapply(seq_len(n_folds), function(k) {
    sample(seq_len(n_docs), n_train, replace = FALSE)
  })

  pred_routine <- foreach::foreach(i = seq_len(n_folds)) %dopar% {

    message(paste0("Splitting train and test sets for fold number", i, "..."))

    which_train <- train_sets[[i]]

    # define IVs and DV for training and test sets
    X_train <- X[which_train, , drop = FALSE]
    Y_train <- Y[which_train]
    X_test <- X[-which_train, , drop = FALSE]

    # bind training outcome/features
    trainbind <- cbind(Y_train, X_train)

    message(paste0("Modeling fold number", i, "..."))

    # Per-fold ranger seed so forests are reproducible inside the workers.
    rf1 <- ranger::ranger(data = trainbind,
                          dependent.variable.name = "Y_train",
                          classification = TRUE,
                          probability = TRUE,
                          importance = store_importance,
                          verbose = FALSE,
                          seed = seed + i)

    message(paste0("Predicting test set for fold number ", i, "..."))

    # Probability matrix, one column per outcome class. Use the column of the
    # positive (1/TRUE) class when ranger labels it; otherwise fall back to
    # the first column, as the original code did.
    prob <- predict(rf1, X_test)$predictions
    pos_col <- which(colnames(prob) %in% c("1", "TRUE"))
    known_class <- length(pos_col) == 1L
    pos_prob <- if (known_class) prob[, pos_col] else prob[, 1]

    # Match test rows by id so `pt` is aligned with `pos_prob` by construction.
    pt <- dat[match(rownames(X_test), dat$index), ] %>%
      mutate(preds = pos_prob,
             classpred = ifelse(preds >= .5, "R", "D"),
             correct = (classpred == caucus),
             fold = i)

    # Fallback only: when the class column could not be identified, column 1
    # may hold either class, so flip if the classes look inverted.
    if (!known_class && mean(pt$correct) < .5) {
      pt <- pt %>%
        mutate(preds = 1 - preds,
               classpred = ifelse(preds >= .5, "R", "D"),
               correct = (classpred == caucus))
    }

    # take measures of fit for the current fold: overall accuracy, area
    # under the ROC curve and under the precision-recall curve.
    # PRROC takes the classifier scores in `scores.class0` and the class
    # membership (1 = Republican) in `weights.class0`.
    is_rep <- as.numeric(pt$caucus == "R")

    aucroc <- PRROC::roc.curve(scores.class0 = pt$preds,
                               weights.class0 = is_rep)$auc

    auprc <- PRROC::pr.curve(scores.class0 = pt$preds,
                             weights.class0 = is_rep)$auc.integral

    outmat <- data.frame(acc = mean(pt$correct),
                         roc = aucroc,
                         prc = auprc,
                         fold = i)

    # if using variable importance, pull that out too
    imp <- if (store_importance != "none") rf1$variable.importance else NULL

    list(fit_results = outmat,
         test_set_data = pt,
         importance = imp)
  }

  # stack measures of fit and test-set predictions across folds
  outmat <- bind_rows(lapply(pred_routine, function(x) x$fit_results))
  test_sets <- bind_rows(lapply(pred_routine, function(x) x$test_set_data))

  # if storing importance, pull that out
  if (store_importance != "none") {
    imp <- lapply(pred_routine, function(x) x$importance)
  } else {
    imp <- NULL
  }

  message("Done!")
  list(fit_results = outmat,
       preds = test_sets,
       importance = imp)
}

###
### This section runs the learner on the raw data. Skip this section if you do not have the 
### text of the tweets. 
###

{
  #
  # Neutral Group 1
  #

  # Read in the Covid-19 tweets by members of Congress.
  covid.tweets.nrt <- data.table::fread("covid_tweets_nrt_4120.csv")

  covid.tweets.nrt[, week := lubridate::week(date)]
  covid.tweets.nrt[, month := lubridate::month(date)]

  # generalize hashtags starting with 'china' to 'china_hashtag'
  covid.tweets.nrt[, lowtext := stringr::str_replace_all(lowtext, "\\s(#china[\\w]+)", " china_hashtag")]

  covid.tweets.nrt <- as_tibble(covid.tweets.nrt)

  # Keep tweets before 2020-04-01, then build the three samples:
  #   neutral  - tweets flagged in group 1
  #   neutral2 - tweets flagged in group 1, 5 or 7
  #   neutral3 - tweets flagged in group 1, 3, 4, 5 or 7
  covid.tweets.nrt <- subset(covid.tweets.nrt, covid.tweets.nrt$date < "2020-04-01")
  covid.tweets.neutral <- subset(covid.tweets.nrt, covid.tweets.nrt$in.group1 == TRUE) # 0.7557
  covid.tweets.neutral2 <- subset(covid.tweets.nrt,
                                  (covid.tweets.nrt$in.group1 +
                                     covid.tweets.nrt$in.group5 +
                                     covid.tweets.nrt$in.group7) > 0) # 0.8422
  covid.tweets.neutral3 <- subset(covid.tweets.nrt,
                                  (covid.tweets.nrt$in.group1 +
                                     covid.tweets.nrt$in.group5 +
                                     covid.tweets.nrt$in.group7 +
                                     covid.tweets.nrt$in.group3 +
                                     covid.tweets.nrt$in.group4) > 0) # 0.9475

  # Phrases compounded into single tokens in all three models.
  # NOTE(review): "high capacity magainze" is a typo for "magazine", and
  # "town hall" appears twice; both are kept as-is so results match the
  # original runs.
  neutral_compounds <- c("green new deal","nuclear power","cap and trade","clean coal", "gun violence","national security",
                         "tax cut","cut tax","cut taxes", "president trump", "american people", "trump administration",
                         "medicare for all","health care", "health insurance",
                         "single payer", "assault weapon","assault weapons","semi automatic","second amendment","brady bill",
                         "high capacity magainze","high capacity magazines","bump stock","bump stocks","background check",
                         "background checks","build the wall","wall funding","birthright citizenship","nuclear deal",
                         "pro life","pro choice","anti choice","born alive","partial birth","late term",
                         "house democrats","house republicans","senate democrats","senate republicans",
                         "majority leader","minority leader","town hall","donald trump","social security",
                         "law enforcement","preexisting conditions",

                         # covid compounds
                         "world health organization","centers for disease control",
                         "supply chain","shelter in place","defense production act",
                         "personal protective equipment","laid off", "town hall",
                         "wish list","op ed")

  # Running the learning process.
  neutral_cv <- text_to_pred_multifolds(dat = covid.tweets.neutral %>% filter(date < "2020-04-01"),
                                        outvar = "is_republican",
                                        text_col = "lowtext",
                                        doc_col = "index",
                                        to_compound = neutral_compounds,
                                        term_min = 100 / nrow(covid.tweets.neutral),
                                        date_var = "day_relative_1120",
                                        ngrams = c(1:3))

  # Saving Output.
  save(neutral_cv, file = "neutralcv_model.RData")

  data.table::fwrite(neutral_cv$preds, file = "neutralcv_preds.csv")
  data.table::fwrite(neutral_cv$fit_results, file = "neutralcv_cvfit.csv")

  ## select the text columns used for the word-frequency comparison
  document_data <- covid.tweets.neutral %>%
    dplyr::select(date, state_abbrev, lowtext, caucus, in.group2) %>%
    data.frame()

  ## Total tweets by each party
  democratic_tweets <- sum(covid.tweets.neutral$caucus == "D")
  republican_tweets <- sum(covid.tweets.neutral$caucus == "R")

  ## Unnest all words over time
  all_words <- document_data %>%
    unnest_tokens(word, lowtext) %>%
    mutate(
      ## remove the punctuation in words
      word = gsub("[[:punct:]]", "", word)
    ) %>%
    mutate(
      ## create a unique id by party and words
      id = paste0(caucus, "@", word)
    ) %>%
    filter(
      ## drop nondescript words
      !(word %in% c("I", "your", "you", "19", "1", "if",
                    "2", "for", "in", "this", "as", "or", "an",
                    "be", "more", "my",  "will", "it", "has", "im",
                    "from", "that", "there", "their", "here",
                    "is", "are", "have", "about", "at", "do", "not",
                    "our", "with", "on", "can", "the", "by", "they",
                    "to", "and", "of", "a", "i", "we", "us"))
    ) %>%
    plyr::ddply(
      ## count of words by party
      ~id,
      summarize,
      count = n()) %>%
    ## sort by count
    arrange(
      -count
    )

  ## pull the party and word back out of the "party@word" id
  id_parts <- strsplit(as.character(all_words$id), split = "@")
  all_words$caucus <- vapply(id_parts, `[`, character(1), 1)
  all_words$word <- vapply(id_parts, `[`, character(1), 2)

  ## per-party word counts, renamed for merging
  all_words_dem <- all_words %>%
    filter(caucus == "D") %>%
    arrange(-count) %>%
    dplyr::select(word, dem_count = count)

  all_words_rep <- all_words %>%
    filter(caucus == "R") %>%
    arrange(-count) %>%
    dplyr::select(word, rep_count = count)

  ## Merge the Democratic and Republican counts and compare per-tweet rates
  all_words_compare <- all_words_dem %>%
    full_join(
      all_words_rep,
      by = "word"
    ) %>%
    mutate(
      rep_count = tidyr::replace_na(rep_count, 0),
      dem_count = tidyr::replace_na(dem_count, 0),
      pol_diff  = rep_count/republican_tweets - dem_count/democratic_tweets,
      party     = ifelse(pol_diff > 0, "Republican", "Democrat"),
      party     = factor(party, levels = c("Republican", "Democrat"))
    ) %>%
    arrange(
      pol_diff
    )

  ## keep the 15 most Democratic and the 15 most Republican words
  all_words_compare <- all_words_compare[c(1:15, (nrow(all_words_compare) - 14):nrow(all_words_compare)),]

  save(all_words_compare, file = "NeutralDiffInUse.rda")

  #
  # Neutral Group 2
  #

  # Running the learning process.
  neutral_cv2 <- text_to_pred_multifolds(dat = covid.tweets.neutral2 %>% filter(date < "2020-04-01"),
                                         outvar = "is_republican",
                                         text_col = "lowtext",
                                         doc_col = "index",
                                         to_compound = neutral_compounds,
                                         term_min = 100 / nrow(covid.tweets.neutral2),
                                         date_var = "day_relative_1120",
                                         ngrams = c(1:3))

  # Saving Output.
  save(neutral_cv2, file = "neutralcv_model2.RData")

  data.table::fwrite(neutral_cv2$preds, file = "neutralcv_preds2.csv")
  data.table::fwrite(neutral_cv2$fit_results, file = "neutralcv_cvfit2.csv")

  #
  # Neutral Group 3
  #

  # Running the learning process.
  neutral_cv3 <- text_to_pred_multifolds(dat = covid.tweets.neutral3 %>% filter(date < "2020-04-01"),
                                         outvar = "is_republican",
                                         text_col = "lowtext",
                                         doc_col = "index",
                                         to_compound = neutral_compounds,
                                         term_min = 100 / nrow(covid.tweets.neutral3),
                                         date_var = "day_relative_1120",
                                         ngrams = c(1:3))

  # Saving Output (own file, so group 2's model is not overwritten).
  save(neutral_cv3, file = "neutralcv_model3.RData")

  data.table::fwrite(neutral_cv3$preds, file = "neutralcv_preds3.csv")
  data.table::fwrite(neutral_cv3$fit_results, file = "neutralcv_cvfit3.csv")
}

###
### Fit statistics
###

# Mean of every fit column (acc, roc, prc, fold) across folds, per sample.
vapply(read.csv("neutralcv_cvfit.csv"), mean, numeric(1))
vapply(read.csv("neutralcv_cvfit2.csv"), mean, numeric(1))
vapply(read.csv("neutralcv_cvfit3.csv"), mean, numeric(1))